"""Custom SQLAlchemy types for use with the Annotations API database."""

import base64
import binascii
import uuid

from sqlalchemy import types
from sqlalchemy.dialects import postgresql
from sqlalchemy.exc import DontWrapMixin


class InvalidUUID(Exception, DontWrapMixin):
    """
    Raised when a value is not a valid URL-safe encoded UUID.

    The ``DontWrapMixin`` base tells SQLAlchemy not to wrap this exception in
    a ``StatementError`` when it is raised while a statement is being
    executed, so callers can catch ``InvalidUUID`` directly.
    """


class URLSafeUUID(
    types.TypeDecorator
):  # pylint:disable=abstract-method,too-many-ancestors
    """
    Column type that stores UUIDs natively and exposes URL-safe strings.

    Values are kept in a PostgreSQL UUID column, while application code only
    ever sees the unpadded, URL-safe base64 encoding of the UUID's 16 bytes
    (a 22-character string).

    20-character IDs, which decode to 15 bytes (post-v1.4 ElasticSearch flake
    IDs), are handled transparently as well: they are widened to 16 bytes on
    the way into the database and narrowed again on the way out.
    """

    impl = postgresql.UUID
    cache_ok = True

    def __init__(self):
        # All conversion is done by this class, so `as_uuid` is deliberately
        # switched off on the underlying PostgreSQL type.
        super().__init__(as_uuid=False)

    def process_bind_param(self, value, dialect):
        """Convert an application-level URL-safe ID to hex for the database."""
        return self.url_safe_to_hex(value)

    def process_result_value(self, value, dialect):
        """Convert a hex UUID read from the database to its URL-safe form."""
        return self.hex_to_url_safe(value)

    @classmethod
    def url_safe_to_hex(cls, value):
        """
        Translate a URL-safe ID into the hex form stored in the database.

        :param value: a 22-character encoded UUID, a 20-character encoded
            ElasticSearch flake ID, or ``None``
        :return: a 32-character hex string, or ``None`` if `value` is ``None``
        :raises InvalidUUID: if `value` is not a ``str``, cannot be base64
            decoded, or does not have one of the two supported lengths
        """
        if value is None:
            return None

        if not isinstance(value, str):
            raise InvalidUUID(f"`url_safe` is {type(value)}, expected str")

        encoded = value.encode()

        try:
            # The URL-safe form carries no "=" padding, so re-append it
            # before handing the bytes to the decoder.
            raw = base64.urlsafe_b64decode(encoded + b"==")
        except binascii.Error as err:
            raise InvalidUUID(f"{value!r} is not a valid encoded UUID") from err

        hex_str = raw.hex()

        # Both the encoded length and the decoded length have to match one of
        # the two supported shapes.
        if len(encoded) == 22 and len(hex_str) == 32:  # A normal UUID
            return hex_str

        if len(encoded) == 20 and len(hex_str) == 30:  # ElasticSearch flake ID
            return cls._add_magic_byte(hex_str)

        raise InvalidUUID(f"{value!r} is not a valid encoded UUID")

    @classmethod
    def hex_to_url_safe(cls, value):
        """
        Translate a hex UUID from the database into its URL-safe form.

        :param value: a UUID in any hex notation ``uuid.UUID`` accepts, or
            ``None``
        :return: the unpadded URL-safe base64 string (22 characters, or 20
            for a value carrying the flake ID marker nibbles), or ``None`` if
            `value` is ``None``
        :raises ValueError: if ``uuid.UUID`` rejects `value`
        """
        if value is None:
            return None

        # Round-tripping through uuid.UUID validates the input and yields a
        # plain 32-character lowercase hex string.
        normalised = uuid.UUID(hex=value).hex

        if cls._has_magic_byte(normalised):  # ElasticSearch flake ID
            normalised = cls._remove_magic_byte(normalised)

        # Encode the raw bytes, then drop the base64 padding.
        encoded = base64.urlsafe_b64encode(bytes.fromhex(normalised))
        return encoded.decode().rstrip("=")

    # The "magic byte": two hexadecimal nibbles used to widen a 15-byte
    # ElasticSearch flake ID into a 16-byte UUID.
    #
    # In its canonical hexadecimal representation a UUID has the form
    #
    #     xxxxxxxx-xxxx-Mxxx-Nxxx-xxxxxxxxxxxx
    #
    # where M is the version field and N the variant field. In specified UUID
    # types the four bits of M take the values {1, 2, 3, 4, 5} and the first
    # three bits of N take the values {8, 9, 0xa, 0xb}.
    #
    # A flake ID is made storable in a UUID column by inserting the nibble
    # 0xe in the version position and 0x5 in the variant position. No
    # specified UUID uses those values, so a widened flake ID can be told
    # apart from UUIDs generated by, for example, PostgreSQL's
    # uuid_generate_v1mc(), and mapped back to its 20-character form.
    _MAGIC_BYTE = ["e", "5"]

    @classmethod
    def _has_magic_byte(cls, hex_str):
        """Return True if `hex_str` has both marker nibbles in place."""
        version_nibble, variant_nibble = cls._MAGIC_BYTE
        return hex_str[12] == version_nibble and hex_str[16] == variant_nibble

    @classmethod
    def _add_magic_byte(cls, hex_str):
        """Widen 30 hex digits to 32 by inserting the marker nibbles."""
        version_nibble, variant_nibble = cls._MAGIC_BYTE
        # The markers end up at indexes 12 and 16 of the result.
        return "".join(
            (
                hex_str[:12],
                version_nibble,
                hex_str[12:15],
                variant_nibble,
                hex_str[15:],
            )
        )

    @classmethod
    def _remove_magic_byte(cls, hex_str):
        """Narrow 32 hex digits to 30 by dropping the marker nibbles."""
        # Everything except indexes 12 and 16 is the flake ID itself.
        return "".join((hex_str[:12], hex_str[13:16], hex_str[17:32]))


class AnnotationSelectorJSONB(
    types.TypeDecorator
):  # pylint:disable=abstract-method, too-many-ancestors
    r"""
    Special type for the Annotation selector column.

    It transparently escapes NULL (\u0000) bytes to \\u0000 when writing to the
    database, and the other way around when reading from the database, but
    only on the prefix/exact/suffix fields in a TextQuoteSelector.

    NOTE(review): the escaping is presumably needed because PostgreSQL's
    jsonb type cannot store \u0000 in text values -- confirm.
    """

    impl = postgresql.JSONB

    # This decorator takes no constructor arguments and keeps no per-instance
    # state, so it is safe to use as part of SQLAlchemy's statement cache key.
    # Leaving this False would disable caching for every statement that
    # involves the column.
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Escape null bytes in text quote selectors before writing `value`."""
        return _transform_quote_selector(value, _escape_null_byte)

    def process_result_value(self, value, dialect):
        """Restore null bytes in text quote selectors after reading `value`."""
        return _transform_quote_selector(value, _unescape_null_byte)


def _transform_quote_selector(selectors, transform_func):  # pragma: no cover
    """
    Apply `transform_func` to the text fields of every TextQuoteSelector.

    Only dict entries whose ``"type"`` is ``"TextQuoteSelector"`` are touched,
    and within those only the ``prefix``, ``exact`` and ``suffix`` keys that
    are actually present. Every other entry is passed through as is.

    The input is never modified: a new list is returned, and each transformed
    selector is a shallow copy of the original dict. This matters on the bind
    path, where `selectors` is the application's own in-memory value and must
    not be rewritten with the escaped text.

    :param selectors: the selector list, or any other value (including
        ``None``), which is returned unchanged
    :param transform_func: callable applied to each text field value
    :return: a new list of selectors, or `selectors` itself if it is not a
        list
    """
    if not isinstance(selectors, list):
        # Covers None as well as any non-list value.
        return selectors

    transformed = []

    for selector in selectors:
        is_text_quote = (
            isinstance(selector, dict)
            and selector.get("type") == "TextQuoteSelector"
        )

        if not is_text_quote:
            transformed.append(selector)
            continue

        # Shallow copy so the caller's dict keeps its original strings.
        updated = dict(selector)
        for field in ("prefix", "exact", "suffix"):
            if field in updated:
                updated[field] = transform_func(updated[field])
        transformed.append(updated)

    return transformed


def _escape_null_byte(string):
    r"""
    Replace each NULL character in `string` with the literal text \u0000.

    ``None`` is returned unchanged.
    """
    return None if string is None else string.replace("\u0000", "\\u0000")


def _unescape_null_byte(string):
    r"""
    Replace each literal \u0000 sequence in `string` with a NULL character.

    ``None`` is returned unchanged. Note that text which contained the
    literal six-character sequence to begin with cannot be told apart from an
    escaped NULL, so it is converted as well.
    """
    return None if string is None else string.replace("\\u0000", "\u0000")
